# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
from functools import partial, wraps
from typing import Any, Dict

from transformers import AutoConfig
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from swift.llm import TemplateType
from ..constant import MLLMModelType
from ..model_arch import ModelArch
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
                        get_model_tokenizer_with_flash_attn, register_model)
from ..utils import ModelInfo, git_clone_github, safe_snapshot_download


def get_model_tokenizer_llava_llama(model_dir: str,
                                    model_info: ModelInfo,
                                    model_kwargs: Dict[str, Any],
                                    load_model: bool = True,
                                    **kwargs):
    """Load a LLaVA checkpoint using an explicitly constructed ``LlavaConfig``.

    Args:
        model_dir: Directory given to ``LlavaConfig.from_pretrained`` and to the multimodal loader.
        model_info: Forwarded unchanged to ``get_model_tokenizer_multimodal``.
        model_kwargs: Forwarded unchanged to ``get_model_tokenizer_multimodal``.
        load_model: Forwarded unchanged to ``get_model_tokenizer_multimodal``.
        **kwargs: Extra loader options. ``model_config`` is always overwritten here;
            ``automodel_class`` falls back to ``LlavaForConditionalGeneration`` when missing or falsy.

    Returns:
        The ``(model, processor)`` pair produced by ``get_model_tokenizer_multimodal``.
    """
    from transformers import LlavaConfig, LlavaForConditionalGeneration

    kwargs['model_config'] = LlavaConfig.from_pretrained(model_dir)
    # ``.get`` keeps a direct call without ``automodel_class`` from raising KeyError.
    kwargs['automodel_class'] = kwargs.get('automodel_class') or LlavaForConditionalGeneration
    return get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)


# HF-format LLaVA-Llama-3 8B checkpoint. Loaded through get_model_tokenizer_llava_llama,
# which builds the LlavaConfig explicitly before delegating to the multimodal loader.
register_model(
    ModelMeta(
        MLLMModelType.llava_llama3_hf,
        [
            ModelGroup([
                Model('AI-ModelScope/llava-llama-3-8b-v1_1-transformers', 'xtuner/llava-llama-3-8b-v1_1-transformers'),
            ]),
        ],
        TemplateType.llava_llama3_hf,
        get_model_tokenizer_llava_llama,
        architectures=['LlavaForConditionalGeneration'],
        model_arch=ModelArch.llava_hf,
        requires=['transformers>=4.36'],
        tags=['vision'],
    ))


def _patch_llava(model):
    """Wrap ``model.generate`` so ``input_ids=`` can stand in for the first positional argument.

    The original callable is kept on the instance as ``__old_generate``; that attribute
    also serves as the marker that prevents patching the same object twice.

    Args:
        model: Any object exposing a ``generate`` callable. It is modified in place.
    """
    already_patched = hasattr(model, '__old_generate')
    if already_patched:
        return

    original_generate = model.generate
    model.__old_generate = original_generate

    @wraps(original_generate)
    def _generate_with_input_ids(inputs=None, *args, **kwargs):
        # ``input_ids`` is always removed from kwargs; it is only used when no
        # positional ``inputs`` value was supplied.
        fallback_ids = kwargs.pop('input_ids', None)
        if inputs is None and fallback_ids is not None:
            inputs = fallback_ids
        return original_generate(inputs, *args, **kwargs)

    model.generate = _generate_with_input_ids


def get_model_tokenizer_llava_hf(model_dir: str, *args, **kwargs):
    """Load an HF-format LLaVA checkpoint through the multimodal loader.

    Args:
        model_dir: Directory forwarded to ``get_model_tokenizer_multimodal``.
        *args: Forwarded unchanged to ``get_model_tokenizer_multimodal``.
        **kwargs: Forwarded loader options. ``automodel_class`` falls back to
            ``LlavaForConditionalGeneration`` when missing or falsy.

    Returns:
        The ``(model, processor)`` pair produced by ``get_model_tokenizer_multimodal``.
    """
    from transformers import LlavaForConditionalGeneration

    # ``.get`` keeps a direct call without ``automodel_class`` from raising KeyError.
    kwargs['automodel_class'] = kwargs.get('automodel_class') or LlavaForConditionalGeneration
    return get_model_tokenizer_multimodal(model_dir, *args, **kwargs)


# HF-format LLaVA-1.5 checkpoints (7B and 13B), loaded with get_model_tokenizer_llava_hf.
register_model(
    ModelMeta(
        MLLMModelType.llava1_5_hf,
        [
            ModelGroup([
                Model('llava-hf/llava-1.5-7b-hf', 'llava-hf/llava-1.5-7b-hf'),
                Model('llava-hf/llava-1.5-13b-hf', 'llava-hf/llava-1.5-13b-hf'),
            ]),
        ],
        TemplateType.llava1_5_hf,
        get_model_tokenizer_llava_hf,
        architectures=['LlavaForConditionalGeneration'],
        model_arch=ModelArch.llava_hf,
        requires=['transformers>=4.36'],
        tags=['vision'],
    ))


def get_model_tokenizer_llava_onevision(*args, **kwargs):
    """Load an HF-format LLaVA-OneVision checkpoint.

    Args:
        *args: Forwarded unchanged to ``get_model_tokenizer_llava_hf``.
        **kwargs: Forwarded loader options. ``automodel_class`` falls back to
            ``LlavaOnevisionForConditionalGeneration`` when missing or falsy.

    Returns:
        The ``(model, processor)`` pair returned by ``get_model_tokenizer_llava_hf``.
    """
    from transformers import LlavaOnevisionForConditionalGeneration

    # ``.get`` keeps a direct call without ``automodel_class`` from raising KeyError.
    kwargs['automodel_class'] = kwargs.get('automodel_class') or LlavaOnevisionForConditionalGeneration
    return get_model_tokenizer_llava_hf(*args, **kwargs)


# HF-format LLaVA-OneVision checkpoints (Qwen2 0.5B / 7B / 72B); tagged 'vision' and 'video'.
register_model(
    ModelMeta(
        MLLMModelType.llava_onevision_hf,
        [
            ModelGroup([
                Model('llava-hf/llava-onevision-qwen2-0.5b-ov-hf', 'llava-hf/llava-onevision-qwen2-0.5b-ov-hf'),
                Model('llava-hf/llava-onevision-qwen2-7b-ov-hf', 'llava-hf/llava-onevision-qwen2-7b-ov-hf'),
                Model('llava-hf/llava-onevision-qwen2-72b-ov-hf', 'llava-hf/llava-onevision-qwen2-72b-ov-hf'),
            ], ),
        ],
        TemplateType.llava_onevision_hf,
        get_model_tokenizer_llava_onevision,
        architectures=['LlavaOnevisionForConditionalGeneration'],
        model_arch=ModelArch.llava_hf,
        requires=['transformers>=4.45'],
        tags=['vision', 'video'],
    ))


def get_model_tokenizer_llava_next(*args, **kwargs):
    """Load an HF-format LLaVA-NeXT checkpoint.

    Args:
        *args: Forwarded unchanged to ``get_model_tokenizer_llava_hf``.
        **kwargs: Forwarded loader options. ``automodel_class`` falls back to
            ``LlavaNextForConditionalGeneration`` when missing or falsy.

    Returns:
        The ``(model, processor)`` pair returned by ``get_model_tokenizer_llava_hf``.
    """
    from transformers import LlavaNextForConditionalGeneration

    # ``.get`` keeps a direct call without ``automodel_class`` from raising KeyError.
    kwargs['automodel_class'] = kwargs.get('automodel_class') or LlavaNextForConditionalGeneration
    return get_model_tokenizer_llava_hf(*args, **kwargs)


# HF-format LLaVA-NeXT 72B / 110B checkpoints. The declared architecture is
# LlavaNextForConditionalGeneration, so they are loaded with the LLaVA-NeXT loader;
# the OneVision loader would import a class that the declared transformers>=4.39
# requirement does not cover (the OneVision entry requires >=4.45).
register_model(
    ModelMeta(
        MLLMModelType.llava_next_qwen_hf,
        [
            ModelGroup([
                Model('llava-hf/llava-next-72b-hf', 'llava-hf/llava-next-72b-hf'),
                Model('llava-hf/llava-next-110b-hf', 'llava-hf/llava-next-110b-hf'),
            ], ),
        ],
        TemplateType.llava_next_qwen_hf,
        get_model_tokenizer_llava_next,
        architectures=['LlavaNextForConditionalGeneration'],
        model_arch=ModelArch.llava_hf,
        requires=['transformers>=4.39'],
        tags=['vision'],
    ))


# HF-format LLaVA-NeXT checkpoint named llama3-llava-next-8b-hf, loaded with get_model_tokenizer_llava_next.
register_model(
    ModelMeta(
        MLLMModelType.llama3_llava_next_hf,
        [
            ModelGroup([
                Model('llava-hf/llama3-llava-next-8b-hf', 'llava-hf/llama3-llava-next-8b-hf'),
            ], ),
        ],
        TemplateType.llama3_llava_next_hf,
        get_model_tokenizer_llava_next,
        architectures=['LlavaNextForConditionalGeneration'],
        model_arch=ModelArch.llava_hf,
        requires=['transformers>=4.39'],
        tags=['vision'],
    ))

# HF-format LLaVA v1.6 Vicuna checkpoints (7B and 13B), loaded with get_model_tokenizer_llava_next.
register_model(
    ModelMeta(
        MLLMModelType.llava1_6_vicuna_hf,
        [
            ModelGroup([
                Model('llava-hf/llava-v1.6-vicuna-7b-hf', 'llava-hf/llava-v1.6-vicuna-7b-hf'),
                Model('llava-hf/llava-v1.6-vicuna-13b-hf', 'llava-hf/llava-v1.6-vicuna-13b-hf'),
            ], ),
        ],
        TemplateType.llava1_6_vicuna_hf,
        get_model_tokenizer_llava_next,
        architectures=['LlavaNextForConditionalGeneration'],
        model_arch=ModelArch.llava_hf,
        requires=['transformers>=4.39'],
        tags=['vision'],
    ))

# HF-format LLaVA v1.6 Mistral 7B checkpoint, loaded with get_model_tokenizer_llava_next.
register_model(
    ModelMeta(
        MLLMModelType.llava1_6_mistral_hf,
        [
            ModelGroup([
                Model('llava-hf/llava-v1.6-mistral-7b-hf', 'llava-hf/llava-v1.6-mistral-7b-hf'),
            ], ),
        ],
        TemplateType.llava1_6_mistral_hf,
        get_model_tokenizer_llava_next,
        architectures=['LlavaNextForConditionalGeneration'],
        model_arch=ModelArch.llava_hf,
        requires=['transformers>=4.39'],
        tags=['vision'],
    ))

# swift/llava-llama3.1-8b, registered with a single model id and loaded with
# get_model_tokenizer_llava_next; this entry requires transformers>=4.41.
register_model(
    ModelMeta(
        MLLMModelType.llava_llama3_1_hf,
        [
            ModelGroup([
                Model('swift/llava-llama3.1-8b'),
            ], ),
        ],
        TemplateType.llava_llama3_1_hf,
        get_model_tokenizer_llava_next,
        architectures=['LlavaNextForConditionalGeneration'],
        model_arch=ModelArch.llava_hf,
        requires=['transformers>=4.41'],
        tags=['vision'],
    ))


def get_model_tokenizer_llava_next_yi(*args, **kwargs):
    """Load a LLaVA-NeXT checkpoint and set its image token index to 64003.

    All arguments are forwarded to ``get_model_tokenizer_llava_next``. The config
    override is skipped when no model object was returned.

    Returns:
        The ``(model, tokenizer)`` pair from ``get_model_tokenizer_llava_next``.
    """
    loaded = get_model_tokenizer_llava_next(*args, **kwargs)
    model, tokenizer = loaded
    if model is None:
        return model, tokenizer
    model.config.image_token_index = 64003
    return model, tokenizer


# HF-format LLaVA v1.6 34B checkpoint. Uses get_model_tokenizer_llava_next_yi, which
# overrides config.image_token_index after loading.
register_model(
    ModelMeta(
        MLLMModelType.llava1_6_yi_hf,
        [
            ModelGroup([
                Model('llava-hf/llava-v1.6-34b-hf', 'llava-hf/llava-v1.6-34b-hf'),
            ], ),
        ],
        TemplateType.llava1_6_yi_hf,
        get_model_tokenizer_llava_next_yi,
        architectures=['LlavaNextForConditionalGeneration'],
        model_arch=ModelArch.llava_hf,
        requires=['transformers>=4.39'],
        tags=['vision'],
    ))


def get_model_tokenizer_llava_next_video(*args, **kwargs):
    """Load an HF-format LLaVA-NeXT-Video checkpoint.

    Args:
        *args: Forwarded unchanged to ``get_model_tokenizer_llava_hf``.
        **kwargs: Forwarded loader options. ``automodel_class`` falls back to
            ``LlavaNextVideoForConditionalGeneration`` when missing or falsy.

    Returns:
        The ``(model, processor)`` pair returned by ``get_model_tokenizer_llava_hf``.
    """
    from transformers import LlavaNextVideoForConditionalGeneration

    # ``.get`` keeps a direct call without ``automodel_class`` from raising KeyError.
    kwargs['automodel_class'] = kwargs.get('automodel_class') or LlavaNextVideoForConditionalGeneration
    return get_model_tokenizer_llava_hf(*args, **kwargs)


# HF-format LLaVA-NeXT-Video 7B checkpoints (DPO, 32K and base variants), loaded with
# get_model_tokenizer_llava_next_video. The 'av' package is listed as a requirement.
register_model(
    ModelMeta(
        MLLMModelType.llava_next_video_hf,
        [
            ModelGroup([
                Model('llava-hf/LLaVA-NeXT-Video-7B-DPO-hf', 'llava-hf/LLaVA-NeXT-Video-7B-DPO-hf'),
                Model('llava-hf/LLaVA-NeXT-Video-7B-32K-hf', 'llava-hf/LLaVA-NeXT-Video-7B-32K-hf'),
                Model('llava-hf/LLaVA-NeXT-Video-7B-hf', 'llava-hf/LLaVA-NeXT-Video-7B-hf'),
            ], ),
        ],
        TemplateType.llava_next_video_hf,
        get_model_tokenizer_llava_next_video,
        architectures=['LlavaNextVideoForConditionalGeneration'],
        model_arch=ModelArch.llava_next_video_hf,
        requires=['transformers>=4.42', 'av'],
        tags=['video'],
    ))


def get_model_tokenizer_llava_next_video_yi(*args, **kwargs):
    """Load a LLaVA-NeXT-Video checkpoint and override its video/image token indices.

    All arguments are forwarded to ``get_model_tokenizer_llava_next_video``. When a
    model object is returned, ``video_token_index`` is set to 64003 and
    ``image_token_index`` to 64004 on its config.

    Returns:
        The ``(model, tokenizer)`` pair from ``get_model_tokenizer_llava_next_video``.
    """
    loaded = get_model_tokenizer_llava_next_video(*args, **kwargs)
    model, tokenizer = loaded
    if model is None:
        return model, tokenizer
    config = model.config
    config.video_token_index = 64003
    config.image_token_index = 64004
    return model, tokenizer


# HF-format LLaVA-NeXT-Video 34B checkpoint. It reuses the llava_next_video_hf template but
# loads through get_model_tokenizer_llava_next_video_yi, which overrides the token indices.
register_model(
    ModelMeta(
        MLLMModelType.llava_next_video_yi_hf,
        [
            ModelGroup([
                Model('llava-hf/LLaVA-NeXT-Video-34B-hf', 'llava-hf/LLaVA-NeXT-Video-34B-hf'),
            ], ),
        ],
        TemplateType.llava_next_video_hf,
        get_model_tokenizer_llava_next_video_yi,
        architectures=['LlavaNextVideoForConditionalGeneration'],
        model_arch=ModelArch.llava_next_video_hf,
        requires=['transformers>=4.42', 'av'],
        tags=['video'],
    ))


def get_model_tokenizer_llava(model_dir: str,
                              model_info: ModelInfo,
                              model_kwargs: Dict[str, Any],
                              load_model: bool = True,
                              **kwargs):
    """Load a checkpoint that depends on the ``llava`` package from the original GitHub repos.

    Args:
        model_dir: Checkpoint directory used to load the config and the model.
        model_info: Forwarded unchanged to ``get_model_tokenizer_with_flash_attn``.
        model_kwargs: Forwarded to ``get_model_tokenizer_with_flash_attn``; its optional
            ``device_map`` entry is also used when loading the vision tower.
        load_model: Forwarded unchanged to ``get_model_tokenizer_with_flash_attn``.
        **kwargs: Must contain ``llm_model_type`` (removed before forwarding). It selects the
            repo to clone (values containing ``'next'`` use LLaVA-NeXT) and the model class
            (``'mistral'``, anything containing ``'llama'``, otherwise the Qwen class).
            ``local_repo_path`` may be given to skip cloning.

    Returns:
        ``(model, tokenizer)``. When a model was loaded, the vision tower's image processor
        is attached to the tokenizer as ``tokenizer.image_processor``.

    Raises:
        KeyError: If ``llm_model_type`` is not present in ``kwargs``.
    """
    llm_model_type = kwargs.pop('llm_model_type')
    local_repo_path = kwargs.get('local_repo_path')
    if not local_repo_path:
        if 'next' in llm_model_type:
            repo_path = 'https://github.com/LLaVA-VL/LLaVA-NeXT'
        else:
            repo_path = 'https://github.com/haotian-liu/LLaVA'
        local_repo_path = git_clone_github(repo_path)
    # Make ``llava`` importable. Skip the append when the path is already registered so that
    # repeated loads do not keep growing ``sys.path`` with duplicates.
    if local_repo_path not in sys.path:
        sys.path.append(local_repo_path)

    if llm_model_type == 'mistral':
        from llava.model import LlavaMistralForCausalLM, LlavaMistralConfig
        model_config = LlavaMistralConfig.from_pretrained(model_dir)
        automodel_class = LlavaMistralForCausalLM
    elif 'llama' in llm_model_type:  # llama
        from llava.model import LlavaLlamaForCausalLM, LlavaConfig
        if not hasattr(LlavaLlamaForCausalLM, '__old_forward'):  # Avoid double patching
            forward = LlavaLlamaForCausalLM.forward
            LlavaLlamaForCausalLM.__old_forward = forward

            @wraps(forward)
            def _new_forward(*f_args, **f_kwargs):
                # Drop ``cache_position`` before delegating to the original forward.
                f_kwargs.pop('cache_position', None)
                return forward(*f_args, **f_kwargs)

            LlavaLlamaForCausalLM.forward = _new_forward
        model_config = LlavaConfig.from_pretrained(model_dir)
        automodel_class = LlavaLlamaForCausalLM
    else:  # qwen
        from llava.model import LlavaQwenForCausalLM
        automodel_class = LlavaQwenForCausalLM
        model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)

    # Point the config at a resolved copy of the CLIP vision tower.
    model_config.mm_vision_tower = safe_snapshot_download('AI-ModelScope/clip-vit-large-patch14-336', check_local=True)
    kwargs['model_config'] = model_config
    kwargs['automodel_class'] = automodel_class
    model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)

    if model is not None:
        model.resize_token_embeddings(len(tokenizer))
        vision_tower = model.get_vision_tower()
        # Fall back to the model's own device when no device_map was requested.
        device_map = str(model_kwargs.get('device_map', str(model.device)))
        if not vision_tower.is_loaded:
            vision_tower.load_model(device_map=device_map)
        if not hasattr(model.config, 'max_sequence_length'):
            model.config.max_sequence_length = 2048
        _patch_llava(model)
        tokenizer.image_processor = vision_tower.image_processor
    return model, tokenizer


# llama3-llava-next-8b loaded via the original-repo loader. partial() pins
# llm_model_type='next_llama': it contains 'next' (LLaVA-NeXT repo is cloned) and
# 'llama' (the LlavaLlamaForCausalLM branch of get_model_tokenizer_llava is taken).
register_model(
    ModelMeta(
        MLLMModelType.llama3_llava_next,
        [
            ModelGroup([
                Model('AI-ModelScope/llama3-llava-next-8b', 'lmms-lab/llama3-llava-next-8b'),
            ], ),
        ],
        TemplateType.llama3_llava_next,
        partial(get_model_tokenizer_llava, llm_model_type='next_llama'),
        architectures=['LlavaLlamaForCausalLM'],
        model_arch=ModelArch.llava_llama,
        requires=['transformers>=4.42', 'av'],
        tags=['vision'],
    ))

# llava-v1.6-mistral-7b loaded via the original-repo loader. partial() pins
# llm_model_type='mistral', which selects the LlavaMistralForCausalLM branch of
# get_model_tokenizer_llava and the haotian-liu/LLaVA repo.
register_model(
    ModelMeta(
        MLLMModelType.llava1_6_mistral,
        [
            ModelGroup([
                Model('AI-ModelScope/llava-v1.6-mistral-7b', 'liuhaotian/llava-v1.6-mistral-7b'),
            ], ),
        ],
        TemplateType.llava1_6_mistral,
        partial(get_model_tokenizer_llava, llm_model_type='mistral'),
        requires=['transformers>=4.34'],
        architectures=['LlavaMistralForCausalLM'],
        model_arch=ModelArch.llava_mistral,
        tags=['vision'],
    ))

# llava-v1.6-34b loaded via the original-repo loader. partial() pins llm_model_type='llama',
# which selects the LlavaLlamaForCausalLM branch of get_model_tokenizer_llava and the
# haotian-liu/LLaVA repo. No model_arch is declared for this entry.
register_model(
    ModelMeta(
        MLLMModelType.llava1_6_yi, [
            ModelGroup([
                Model('AI-ModelScope/llava-v1.6-34b', 'liuhaotian/llava-v1.6-34b'),
            ], ),
        ],
        TemplateType.llava1_6_yi,
        partial(get_model_tokenizer_llava, llm_model_type='llama'),
        requires=['transformers>=4.34'],
        architectures=['LlavaLlamaForCausalLM'],
        tags=['vision'],
        model_arch=None))

# llava-next 72B / 110B loaded via the original-repo loader. partial() pins
# llm_model_type='next_qwen': it contains 'next' (LLaVA-NeXT repo is cloned) and falls through
# to the LlavaQwenForCausalLM branch of get_model_tokenizer_llava. No model_arch is declared.
register_model(
    ModelMeta(
        MLLMModelType.llava_next_qwen, [
            ModelGroup([
                Model('AI-ModelScope/llava-next-72b', 'lmms-lab/llava-next-72b'),
                Model('AI-ModelScope/llava-next-110b', 'lmms-lab/llava-next-110b'),
            ], ),
        ],
        TemplateType.llava_next_qwen,
        partial(get_model_tokenizer_llava, llm_model_type='next_qwen'),
        architectures=['LlavaQwenForCausalLM'],
        requires=['transformers>=4.42', 'av'],
        tags=['vision'],
        model_arch=None))


def get_model_tokenizer_llava_onevision1_5(model_dir, *args, **kwargs):
    """Load a LLaVA-OneVision-1.5 checkpoint whose model class ships with the checkpoint.

    The dynamic model class is resolved from ``model_dir`` and its ``_no_split_modules``
    list is overwritten before the model is loaded.

    Args:
        model_dir: Checkpoint directory containing ``modeling_llavaonevision1_5``.
        *args: Forwarded unchanged to ``get_model_tokenizer_multimodal``.
        **kwargs: Forwarded unchanged to ``get_model_tokenizer_multimodal``.

    Returns:
        The ``(model, processor)`` pair produced by ``get_model_tokenizer_multimodal``.
    """
    model_cls = get_class_from_dynamic_module('modeling_llavaonevision1_5.LLaVAOneVision1_5_ForConditionalGeneration',
                                              model_dir)
    model_cls._no_split_modules = ['LLaVAOneVision1_5_DecoderLayer', 'RiceBlock']
    model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs)
    # The loader may return no model (the sibling loaders in this module guard for it);
    # only touch the config when a model object exists.
    if model is not None:
        model.config.vision_start_token_id = 151652
    return model, processor


# LLaVA-OneVision-1.5 checkpoints (4B / 8B, Instruct and Base), loaded with
# get_model_tokenizer_llava_onevision1_5. 'qwen_vl_utils' is listed as a requirement.
register_model(
    ModelMeta(
        MLLMModelType.llava_onevision1_5,
        [
            ModelGroup([
                Model('lmms-lab/LLaVA-OneVision-1.5-4B-Instruct', 'lmms-lab/LLaVA-OneVision-1.5-4B-Instruct'),
                Model('lmms-lab/LLaVA-OneVision-1.5-8B-Instruct', 'lmms-lab/LLaVA-OneVision-1.5-8B-Instruct'),
                Model('lmms-lab/LLaVA-OneVision-1.5-4B-Base', 'lmms-lab/LLaVA-OneVision-1.5-4B-Base'),
                Model('lmms-lab/LLaVA-OneVision-1.5-8B-Base', 'lmms-lab/LLaVA-OneVision-1.5-8B-Base'),
            ], ),
        ],
        TemplateType.llava_onevision1_5,
        get_model_tokenizer_llava_onevision1_5,
        architectures=['LLaVAOneVision1_5_ForConditionalGeneration'],
        model_arch=ModelArch.llava_onevision1_5,
        requires=['transformers>=4.53.0', 'qwen_vl_utils'],
        tags=['vision'],
    ))
